Generative Adversarial Networks (GAN)


By Prof. Seungchul Lee
http://iai.postech.ac.kr/
Industrial AI Lab at POSTECH

Table of Contents

Source

  • CS231n: CNN for Visual Recognition

1. Discriminative Model vs. Generative Model

  • Discriminative model




  • Generative model



2. Density Function Estimation

  • Probability
  • What if $x$ is an actual image from the training data? In that case, $x$ can be represented as a (for example) $64\times 64 \times 3$-dimensional vector.
    • the following images are some realizations (samples) from this $64\times 64 \times 3$-dimensional space
  • Probability density function estimation problem
  • If $P_{\text{model}}(x)$ can be estimated as close to $P_{\text{data}}(x)$, then data can be generated by sampling from $P_{\text{model}}(x)$.

    • Note: the Kullback–Leibler divergence is a measure of discrepancy between two distributions (not a true distance, since it is asymmetric)
  • Learn a deterministic transformation via a neural network
    • Start by sampling the code vector $z$ from a simple, fixed distribution such as a uniform distribution or a standard Gaussian $\mathcal{N}(0,I)$
    • Then this code vector is passed as input to a deterministic generator network $G$, which produces an output sample $x=G(z)$
    • This is the role a neural network plays in a generative model: a nonlinear mapping to a target probability density function
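The KL divergence mentioned above can be computed directly for discrete distributions. A minimal numerical sketch (the probability values below are made up for illustration):

```python
import numpy as np

# D_KL(P || Q) = sum_x P(x) log(P(x) / Q(x)) for discrete distributions.
# P and Q are hypothetical 3-outcome distributions, chosen only to illustrate.
P = np.array([0.5, 0.3, 0.2])
Q = np.array([0.4, 0.4, 0.2])

kl = np.sum(P * np.log(P / Q))
print(kl)          # ≈ 0.0253 (> 0: the distributions differ)

kl_self = np.sum(P * np.log(P / P))
print(kl_self)     # 0.0 (KL divergence vanishes only when P == Q)
```

Note the asymmetry: $D_{KL}(P\|Q) \neq D_{KL}(Q\|P)$ in general, which is why it is only "a kind of" distance measure.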



  • An example of a generator network which encodes a univariate distribution with two different modes
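The bimodal example above can be imitated with a hand-crafted deterministic map. This is a toy stand-in for a trained generator, not an actual network: it shows how a fixed transformation of Gaussian codes $z$ can produce two modes.

```python
import numpy as np

# Toy illustration (not a trained network): a deterministic transformation
# mapping standard-Gaussian codes z to a bimodal univariate distribution.
rng = np.random.default_rng(0)
z = rng.standard_normal(10000)

# Hand-crafted nonlinear map standing in for G: codes with z < 0 land near -2,
# codes with z >= 0 land near +2, producing two well-separated modes.
x = np.where(z < 0, -2.0, 2.0) + 0.3 * z

print(x.mean())                 # near 0 by symmetry
print((np.abs(x) > 1).mean())   # all mass sits away from 0: two modes
```

A trained generator learns such a mapping from data instead of having it written by hand, but the principle is the same: a deterministic function reshapes a simple input density into the target one.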

  • Generative model of high dimensional space
  • Generative model of images
    • learn a function which maps independent, normally-distributed $z$ values to whatever latent variables might be needed by the model, and then map those latent variables to $x$ (as images)
    • first few layers to map the normally distributed $z$ to the latent values
    • then, use later layers to map those latent values to an image



3. Generative Adversarial Networks (GAN)

  • In generative modeling, we'd like to train a network that models a distribution, such as a distribution over images.

  • GANs do not work with any explicit density function!

  • Instead, take game-theoretic approach

3.1. Adversarial Nets Framework

  • One way to judge the quality of the model is to sample from it.

  • Train the model to produce samples that are indistinguishable from the real data, as judged by a discriminator network whose job is to tell real from fake





  • The idea behind Generative Adversarial Networks (GANs): train two different networks


  • Discriminator network: try to distinguish between real and fake data


  • Generator network: try to produce realistic-looking samples to fool the discriminator network


3.2. Objective Function of GAN

  • Think about a logistic regression classifier (or the cross-entropy loss between prediction $h(x)$ and label $y$)


$$\text{loss} = -y \log h(x) - (1-y) \log (1-h(x))$$
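This loss can be evaluated directly. A minimal sketch for a single prediction $h(x)$ and label $y$, with made-up values:

```python
import numpy as np

# Cross-entropy loss for one prediction h in (0, 1) and label y in {0, 1}:
# loss = -y log(h) - (1 - y) log(1 - h)
def bce(h, y):
    return -y * np.log(h) - (1 - y) * np.log(1 - h)

print(bce(0.9, 1))   # ≈ 0.105: confident and correct, small loss
print(bce(0.9, 0))   # ≈ 2.303: confident and wrong, large loss
```

The discriminator below is trained with exactly this loss (`binary_crossentropy` in Keras), with $y=1$ for real images and $y=0$ for generated ones.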

  • To train the discriminator


  • To train the generator


  • Non-Saturating Game when the generator is trained

    • Early in learning, when $G$ is poor, $D$ can reject samples with high confidence because they are clearly different from the training data. In this case, $\log(1-D(G(z)))$ saturates.

    • Rather than training $G$ to minimize $\log(1-D(G(z)))$ we can train $G$ to maximize $\log D(G(z))$. This objective function provides much stronger gradients early in learning.
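The saturation argument above can be checked numerically. A toy sketch, where `d` is a hypothetical early-training value of $D(G(z))$ (close to 0 because $D$ easily rejects poor fakes):

```python
# When d = D(G(z)) is near 0, the gradient of log(1 - D) w.r.t. D is tiny,
# while the gradient of log D is large -- the non-saturating objective.
d = 0.01   # hypothetical discriminator output on a fake, early in training

grad_saturating = -1.0 / (1.0 - d)   # d/dD of log(1 - D): magnitude ≈ 1
grad_nonsaturating = 1.0 / d         # d/dD of log D: magnitude ≈ 100

print(abs(grad_saturating))      # ≈ 1.01
print(abs(grad_nonsaturating))   # ≈ 100
```

The non-saturating gradient is roughly $1/d$ times stronger when $d$ is small, which is why maximizing $\log D(G(z))$ trains the generator much faster early on.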

3.3. Solving a Minimax Problem


Step 1: Fix $G$ and perform a gradient step to


$$\max_{D} E_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right] + E_{z \sim p_{z}(z)}\left[\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to


$$\max_{G} E_{z \sim p_{z}(z)}\left[\log D(G(z))\right]$$

OR



Step 1: Fix $G$ and perform a gradient step to


$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to


$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$
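Although not derived here, a standard result (Goodfellow et al., 2014) is that for fixed $G$, the pointwise maximizer in Step 1 is $D^*(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)}$. A quick numerical check, with made-up density values at a single point $x$:

```python
import numpy as np

# For a fixed x, Step 1 maximizes f(D) = p_data * log(D) + p_g * log(1 - D)
# over D in (0, 1). The density values below are hypothetical.
p_data, p_g = 0.7, 0.3

D = np.linspace(1e-4, 1 - 1e-4, 100000)
f = p_data * np.log(D) + p_g * np.log(1 - D)
D_star = D[np.argmax(f)]

print(D_star)                    # numerically ≈ 0.7
print(p_data / (p_data + p_g))   # analytic optimum: 0.7
```

Intuitively, when the generator perfectly matches the data ($p_g = p_{\text{data}}$), the optimal discriminator outputs $D^*(x) = 1/2$ everywhere: it cannot do better than guessing.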

4. GAN with MNIST

4.1. GAN Implementation

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
In [2]:
mnist = tf.keras.datasets.mnist

(train_x, train_y), _ = mnist.load_data()
train_x = train_x[np.where(train_y == 2)]
train_x = train_x / 255.0
train_x = train_x.reshape(-1, 784)

print('train_images :', train_x.shape)
train_images : (5958, 784)
In [3]:
generator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, input_dim = 100, activation = 'relu'),
    tf.keras.layers.Dense(units = 784, activation = 'sigmoid')    
])
In [4]:
discriminator = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, input_dim = 784, activation = 'relu'),
    tf.keras.layers.Dense(units = 1, activation = 'sigmoid'),
])
In [5]:
discriminator.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0001), 
                      loss = 'binary_crossentropy')
In [6]:
combined_input = tf.keras.layers.Input(shape = (100,))
generated = generator(combined_input)
discriminator.trainable = False
combined_output = discriminator(generated)

combined = tf.keras.models.Model(inputs = combined_input, outputs = combined_output)
In [7]:
combined.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.0002), 
                 loss = 'binary_crossentropy')
In [8]:
def make_noise(samples):
    return np.random.normal(0, 1, [samples, 100])
In [9]:
def plot_generated_images(generator, samples = 3):
    
    noise = make_noise(samples)
    
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(samples, 28, 28)
    
    for i in range(samples):
        plt.subplot(1, samples, i+1)
        plt.imshow(generated_images[i], 'gray', interpolation = 'nearest')
        plt.axis('off')
        plt.tight_layout()
    plt.show()

Step 1: Fix $G$ and perform a gradient step to

$$\min_{D} E_{x \sim p_{\text{data}}(x)}\left[-\log D(x)\right] + E_{z \sim p_{z}(z)}\left[-\log (1-D(G(z)))\right]$$

Step 2: Fix $D$ and perform a gradient step to

$$\min_{G} E_{z \sim p_{z}(z)}\left[-\log D(G(z))\right]$$
In [10]:
n_iter = 20000
batch_size = 100

fake = np.zeros(batch_size)
real = np.ones(batch_size)

for i in range(n_iter):
        
    # Train Discriminator
    noise = make_noise(batch_size)
    generated_images = generator.predict(noise)

    idx = np.random.randint(0, train_x.shape[0], batch_size)
    real_images = train_x[idx]

    D_loss_real = discriminator.train_on_batch(real_images, real)
    D_loss_fake = discriminator.train_on_batch(generated_images, fake)
    D_loss = D_loss_real + D_loss_fake
    
    # Train Generator
    noise = make_noise(batch_size)    
    G_loss = combined.train_on_batch(noise, real)
    
    if i % 5000 == 0:
        
        print('Discriminator Loss: ', D_loss)
        print('Generator Loss: ', G_loss)

        plot_generated_images(generator)
Discriminator Loss:  1.4117062091827393
Generator Loss:  0.9805243611335754
Discriminator Loss:  0.27343617379665375
Generator Loss:  2.3575124740600586
Discriminator Loss:  0.4969727545976639
Generator Loss:  1.827857494354248
Discriminator Loss:  0.4474167227745056
Generator Loss:  1.7708817720413208

4.2. After Training

  • After training, use the generator network to generate new data


In [11]:
plot_generated_images(generator)

5. Conditional GAN

  • In an unconditioned generative model, there is no control on modes of the data being generated.
  • In the Conditional GAN (CGAN), the generator learns to generate a fake sample with a specific condition or characteristics (such as a label associated with an image or a more detailed tag) rather than a generic sample from an unknown noise distribution.




  • Simple modification to the original GAN framework that conditions the model on additional information for better multi-modal learning
  • Many practical applications of GANs when we have explicit supervision available
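The conditioning idea can be sketched in a few lines: the class label is one-hot encoded and concatenated with the noise vector before it enters the generator. A minimal numpy illustration, with dimensions matching the 128-dimensional noise and 10 classes used in this section:

```python
import numpy as np

# Concatenate a one-hot class label with the noise vector so the generator
# input carries the desired digit class.
noise = np.random.normal(0, 1, (1, 128))
label = 3                                # hypothetical target class
one_hot = np.eye(10)[[label]]            # shape (1, 10)

generator_input = np.concatenate([noise, one_hot], axis=1)
print(generator_input.shape)             # (1, 138)
```

The Keras implementation below does the same thing with `CategoryEncoding` and a `concatenate` layer, which is why the generator's first dense layer has `input_dim = 138`.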
In [12]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
In [13]:
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train/255.0 , x_test/255.0
x_train, x_test = x_train.reshape(-1,784), x_test.reshape(-1,784)
y_train, y_test = y_train.reshape(-1, 1), y_test.reshape(-1, 1)

print('x_train: ', x_train.shape)
print('x_test: ', x_test.shape)
print('y_train: ', y_train.shape)
print('y_test: ', y_test.shape)
x_train:  (60000, 784)
x_test:  (10000, 784)
y_train:  (60000, 1)
y_test:  (10000, 1)
In [14]:
generator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256, input_dim = 138, activation = 'relu'),
    tf.keras.layers.Dense(units = 784, activation = 'sigmoid')
])

noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (1,))
label_onehot = tf.keras.layers.CategoryEncoding(10, output_mode='one_hot')(label)

model_input = tf.keras.layers.concatenate([noise, label_onehot], axis = 1)
generated_image = generator_model(model_input)

generator = tf.keras.models.Model([noise, label], generated_image)
In [15]:
generator.summary()
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
category_encoding (CategoryEnco (None, 10)           0           input_3[0][0]                    
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 138)          0           input_2[0][0]                    
                                                                 category_encoding[0][0]          
__________________________________________________________________________________________________
sequential_2 (Sequential)       (None, 784)          237072      concatenate[0][0]                
==================================================================================================
Total params: 237,072
Trainable params: 237,072
Non-trainable params: 0
__________________________________________________________________________________________________
In [16]:
discriminator_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(units = 256,input_dim = 794, activation = 'relu'),
    tf.keras.layers.Dense(units = 1, activation = 'sigmoid')
])

input_image = tf.keras.layers.Input(shape = (784,))
label = tf.keras.layers.Input(shape = (1,))
label_onehot = tf.keras.layers.CategoryEncoding(10, output_mode='one_hot')(label)

model_input = tf.keras.layers.concatenate([input_image, label_onehot], axis = 1)
validity = discriminator_model(model_input)

discriminator = tf.keras.models.Model([input_image, label], validity)
In [17]:
optim_d = tf.keras.optimizers.Adam(learning_rate = 0.0002)
In [18]:
discriminator.compile(loss = ['binary_crossentropy'], 
                      optimizer = optim_d)
In [19]:
discriminator.summary()
Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_5 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 784)]        0                                            
__________________________________________________________________________________________________
category_encoding_1 (CategoryEn (None, 10)           0           input_5[0][0]                    
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 794)          0           input_4[0][0]                    
                                                                 category_encoding_1[0][0]        
__________________________________________________________________________________________________
sequential_3 (Sequential)       (None, 1)            203777      concatenate_1[0][0]              
==================================================================================================
Total params: 203,777
Trainable params: 203,777
Non-trainable params: 0
__________________________________________________________________________________________________
In [20]:
noise = tf.keras.layers.Input(shape = (128,))
label = tf.keras.layers.Input(shape = (1,))
generated_image = generator([noise, label])
discriminator.trainable = False
validity = discriminator([generated_image, label])

combined = tf.keras.models.Model([noise, label], validity)
In [21]:
optim_combined = tf.keras.optimizers.Adam(learning_rate = 0.0002)
In [22]:
combined.compile(loss = ['binary_crossentropy'], 
                 optimizer = optim_combined)
In [23]:
combined.summary()
Model: "model_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_6 (InputLayer)            [(None, 128)]        0                                            
__________________________________________________________________________________________________
input_7 (InputLayer)            [(None, 1)]          0                                            
__________________________________________________________________________________________________
model_1 (Functional)            (None, 784)          237072      input_6[0][0]                    
                                                                 input_7[0][0]                    
__________________________________________________________________________________________________
model_2 (Functional)            (None, 1)            203777      model_1[0][0]                    
                                                                 input_7[0][0]                    
==================================================================================================
Total params: 440,849
Trainable params: 237,072
Non-trainable params: 203,777
__________________________________________________________________________________________________
In [24]:
def create_noise(samples):
    return np.random.normal(0, 1, [samples, 128])
In [25]:
def plot_generated_images(generator):

    noise = create_noise(10)
    label = np.arange(0, 10).reshape(-1, 1)

    generated_images = generator.predict([noise, label])

    plt.figure(figsize = (90, 10))
    for i in range(generated_images.shape[0]):
        plt.subplot(1, 10, i + 1)
        plt.imshow(generated_images[i].reshape((28, 28)), 'gray', interpolation = 'nearest')
        plt.title('Digit: {}'.format(i), fontsize = 75)
        plt.axis('off')
    plt.show()
In [26]:
n_iter = 100000
batch_size = 100

valid = np.ones(batch_size)
fake = np.zeros(batch_size)
for i in range(n_iter):
        
    # Train Discriminator
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_images, labels = x_train[idx], y_train[idx]
    
    noise = create_noise(batch_size)
    generated_images = generator.predict([noise,labels])
    
    d_loss_real = discriminator.train_on_batch([real_images, labels], valid)
    d_loss_fake = discriminator.train_on_batch([generated_images, labels], fake)
    d_loss = d_loss_real + d_loss_fake
    
    # Train Generator
    noise = create_noise(batch_size)
    labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)

    g_loss = combined.train_on_batch([noise, labels], valid)
    
    if i % 5000 == 0:

        print('Discriminator Loss: ', d_loss)
        print('Generator Loss: ', g_loss)

        plot_generated_images(generator)
Discriminator Loss:  2.1221539974212646
Generator Loss:  0.3581492602825165
Discriminator Loss:  0.012442981824278831
Generator Loss:  5.603395938873291
Discriminator Loss:  0.09481578692793846
Generator Loss:  4.591771125793457
Discriminator Loss:  0.08484158292412758
Generator Loss:  4.946230888366699
Discriminator Loss:  0.17704027891159058
Generator Loss:  3.9330806732177734
Discriminator Loss:  0.8256005346775055
Generator Loss:  2.618168354034424
Discriminator Loss:  0.5448089241981506
Generator Loss:  2.8494226932525635
Discriminator Loss:  0.6805634498596191
Generator Loss:  2.3530080318450928
Discriminator Loss:  0.8277974426746368
Generator Loss:  2.0937836170196533
Discriminator Loss:  0.8780734539031982
Generator Loss:  1.734046459197998
Discriminator Loss:  4.7912304401397705
Generator Loss:  0.6568551659584045
Discriminator Loss:  15.288216590881348
Generator Loss:  0.1776898205280304
Discriminator Loss:  0.7976755499839783
Generator Loss:  1.7534655332565308
Discriminator Loss:  0.8163438141345978
Generator Loss:  1.931523084640503
Discriminator Loss:  0.9229690134525299
Generator Loss:  1.693101167678833
Discriminator Loss:  0.8357776999473572
Generator Loss:  1.6274802684783936
Discriminator Loss:  1.3114712238311768
Generator Loss:  1.4318947792053223
Discriminator Loss:  0.9433916509151459
Generator Loss:  1.8737053871154785
Discriminator Loss:  0.8729234039783478
Generator Loss:  1.9886713027954102
Discriminator Loss:  0.4301510900259018
Generator Loss:  2.568551540374756

Generate fake MNIST images with the CGAN

In [27]:
plot_generated_images(generator)

6. Other Tutorials

In [28]:
%%html
<center><iframe src="https://www.youtube.com/embed/9JpdAg6uMXs?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
  • CS231n: CNN for Visual Recognition
In [29]:
%%html
<center><iframe src="https://www.youtube.com/embed/5WoItGTWV54?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>

MIT by Aaron Courville

In [30]:
%%html
<center><iframe src="https://www.youtube.com/embed/JVb54xhEw6Y?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>

Univ. of Waterloo by Ali Ghodsi

In [31]:
%%html
<center><iframe src="https://www.youtube.com/embed/7G4_Y5rsvi8?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
In [32]:
%%html
<center><iframe src="https://www.youtube.com/embed/odpjk7_tGY0?rel=0" 
width="560" height="315" frameborder="0" allowfullscreen></iframe></center>
In [33]:
%%javascript
$.getScript('https://kmahelona.github.io/ipython_notebook_goodies/ipython_notebook_toc.js')